import pandas as pd
import os
import csv
import sys
import re

from surprise import Dataset
from surprise import Reader
from surprise.model_selection import PredefinedKFold

from collections import defaultdict


class YahooDataset:
    """Loader for the Yahoo! Movies ratings files and the movie metadata file.

    All paths are relative. Several loaders first change the process working
    directory to the directory of the running script (``sys.argv[0]``) so
    that the relative ``R4/...`` paths resolve next to the script; the other
    loaders rely on the current working directory as-is.
    """

    ratingsTrainPath = 'R4/ydata-ymovies-user-movie-ratings-train-v1_0.txt'
    ratingsTestPath = 'R4/ydata-ymovies-user-movie-ratings-test-v1_0.txt'
    usersPath = 'R4/ydata-ymovies-user-demographics-v1_0.txt'
    moviesPath = 'R4/movie_db_yoda'
    fullSetPath = 'R4/fullSet.txt'

    # Marker used in the movie metadata file for a missing value.
    _MISSING = "\\N"

    # INTERNAL HELPERS
    @staticmethod
    def _enterScriptDir():
        """Change the working directory to the directory of the running script.

        ``os.path.dirname(sys.argv[0])`` is an empty string when the script
        was started without a directory component (e.g. ``python script.py``)
        and ``os.chdir('')`` raises, so in that case the working directory is
        left untouched.
        """
        scriptDir = os.path.dirname(sys.argv[0])
        if scriptDir:
            os.chdir(scriptDir)

    @staticmethod
    def _centerByGroup(data, groupColumn):
        """Subtract the per-group mean from every value column of ``data``.

        Value columns are all columns other than ``userId`` and ``movieId``.

        Args:
            data: frame containing ``userId``, ``movieId`` and value columns.
            groupColumn: ``"userId"`` or ``"movieId"``; the column whose
                groups define the means.

        Returns:
            A tuple ``(centered, means)``. ``centered`` is a new frame with a
            fresh 0..n-1 index and columns ``userId``, ``movieId`` followed by
            the centered value columns. ``means`` is a frame of the per-group
            means, indexed by ``groupColumn``.
        """
        valueColumns = [c for c in data.columns if c not in ("userId", "movieId")]
        # Work positionally on a fresh index so the result never depends on
        # the index labels of the frame that was passed in.
        frame = data[["userId", "movieId"] + valueColumns].reset_index(drop=True)
        grouped = frame.groupby(groupColumn)[valueColumns]
        means = grouped.mean()
        frame[valueColumns] = frame[valueColumns] - grouped.transform("mean")
        return frame, means

    @classmethod
    def _pipeListToText(cls, field):
        """Turn a ``a|b|c`` field into ``a, b, c``; keep the missing marker as is."""
        if field == cls._MISSING:
            return field
        return ", ".join(field.split("|"))

    # CREATE DATAFRAME
    def loadYahooPandasFullDataFrame(self):
        """Read ``fullSetPath`` (tab separated, no header) into a DataFrame.

        Returns:
            DataFrame with columns ``userId``, ``movieId``, ``rating``,
            ``rating2``. The path is resolved against the current working
            directory (this loader does not change directory).
        """
        # os.chdir(os.path.dirname(sys.argv[0]))
        data = pd.read_csv(self.fullSetPath, delimiter='\t', header=None,
                           names=["userId", "movieId", "rating", "rating2"])
        return data

    def loadDemographicsData(self):
        """Read ``usersPath`` into a DataFrame with ``userId``, ``year``, ``gender``.

        Side effect: changes the working directory to the script directory.
        """
        self._enterScriptDir()
        data = pd.read_csv(self.usersPath, delimiter='\t', header=None,
                           names=["userId", "year", "gender"])
        return data

    def loadYahooPandasTrainingDataFrame(self):
        """Read the training ratings into a DataFrame.

        The file has four tab-separated columns; the third one (named
        ``rating_wide`` here) is dropped, leaving ``userId``, ``movieId``,
        ``rating``.

        Side effect: changes the working directory to the script directory.
        """
        self._enterScriptDir()
        data = pd.read_csv(self.ratingsTrainPath, delimiter='\t', header=None,
                           names=["userId", "movieId", "rating_wide", "rating"],
                           usecols=["userId", "movieId", "rating"])
        return data

    def loadYahooPandasTestDataFrame(self):
        """Read the test ratings into a DataFrame.

        Same layout and column selection as the training loader: the result
        has ``userId``, ``movieId``, ``rating``.

        Side effect: changes the working directory to the script directory.
        """
        self._enterScriptDir()
        data = pd.read_csv(self.ratingsTestPath, delimiter='\t', header=None,
                           names=["userId", "movieId", "rating_wide", "rating"],
                           usecols=["userId", "movieId", "rating"])
        return data

    # NORMALIZE DATASET
    def loadNormalizedData(self, data=None):
        """Center ratings per user, then per movie, and return the means used.

        Args:
            data: ratings frame with ``userId``, ``movieId`` and value
                columns; the training frame is loaded when omitted.

        Returns:
            A tuple ``(allNormalized, usersAverage, itemsAverage)``:
            ``allNormalized`` is the frame after both centering steps,
            ``usersAverage`` holds the per-user means of ``data`` (indexed by
            ``userId``) and ``itemsAverage`` holds the per-movie means of the
            user-centered values (indexed by ``movieId``).
        """
        if data is None:
            data = self.loadYahooPandasTrainingDataFrame()

        normalizedByUser, usersAverage = self._centerByGroup(data, "userId")

        # The item step runs on the already user-centered values, so the item
        # averages are averages of user-centered ratings, not of raw ratings.
        allNormalized, itemsAverage = self._centerByGroup(normalizedByUser, "movieId")

        return (allNormalized, usersAverage, itemsAverage)

    def normalizeByUser(self, data=None):
        """Return ``data`` with each user's mean subtracted from the value columns.

        Args:
            data: ratings frame; the training frame is loaded when omitted.

        Returns:
            New DataFrame with a fresh 0..n-1 index and columns ``userId``,
            ``movieId`` followed by the centered value columns, rows in the
            same order as ``data``.
        """
        if data is None:
            data = self.loadYahooPandasTrainingDataFrame()

        normalizedByUser, _ = self._centerByGroup(data, "userId")

        return normalizedByUser

    def normalizeByItem(self, data=None):
        """Return ``data`` with each movie's mean subtracted from the value columns.

        Args:
            data: ratings frame; the training frame is loaded when omitted.

        Returns:
            New DataFrame with a fresh 0..n-1 index and columns ``userId``,
            ``movieId`` followed by the centered value columns, rows in the
            same order as ``data``.
        """
        if data is None:
            data = self.loadYahooPandasTrainingDataFrame()

        normalizedByItem, _ = self._centerByGroup(data, "movieId")

        return normalizedByItem

    # USE SURPRISE READER
    def loadFromPandas(self, frame):
        """Wrap a ratings DataFrame in a Surprise ``Dataset``.

        The rating scale is taken from the min and max of ``frame["rating"]``.
        """
        reader = Reader(rating_scale=(frame["rating"].min(), frame["rating"].max()))
        data = Dataset.load_from_df(frame, reader=reader)
        return data

    def loadFullSet(self):
        """Load ``fullSetPath`` with Surprise and return the full trainset.

        The path is resolved against the current working directory (this
        loader does not change directory).
        """
        reader = Reader(line_format='user item rating', sep='\t', skip_lines=0)
        fullDataset = Dataset.load_from_file(self.fullSetPath, reader=reader)

        return fullDataset.build_full_trainset()

    def getTestDataGlobalMean(self):
        """Return the mean of the ``rating`` column of the test DataFrame."""
        testData = self.loadYahooPandasTestDataFrame()
        return testData["rating"].mean()

    def loadYahooDataset(self):
        """Return the train/test split defined by the two rating files.

        Returns:
            The generator produced by ``PredefinedKFold().split`` over the
            single (train file, test file) fold.

        Side effects: changes the working directory to the script directory
        and prints the directory listing.
        """
        self._enterScriptDir()
        print(os.listdir())
        # NOTE(review): with this line_format the third column is the rating
        # and the fourth is read as the timestamp, whereas the pandas loaders
        # keep the fourth column as "rating" -- confirm which one is intended.
        reader = Reader(line_format='user item rating timestamp', sep='\t', skip_lines=0)

        data = Dataset.load_from_folds([(self.ratingsTrainPath, self.ratingsTestPath)], reader=reader)
        pkf = PredefinedKFold()

        return pkf.split(data)

    def loadYahooTrainDataset(self):
        """Load the training ratings with Surprise and return the full trainset.

        Side effect: changes the working directory to the script directory.
        """
        self._enterScriptDir()
        # NOTE(review): third column is used as the rating here (see the
        # pandas loaders, which use the fourth) -- confirm.
        reader = Reader(line_format='user item rating timestamp', sep='\t', skip_lines=0)
        ratingsTrainDataset = Dataset.load_from_file(self.ratingsTrainPath, reader=reader)

        return ratingsTrainDataset.build_full_trainset()

    def loadYahooTestDataset(self):
        """Load the test ratings with Surprise and return them as a testset.

        Side effect: changes the working directory to the script directory.
        """
        self._enterScriptDir()
        # NOTE(review): third column is used as the rating here (see the
        # pandas loaders, which use the fourth) -- confirm.
        reader = Reader(line_format='user item rating timestamp', sep='\t', skip_lines=0)
        ratingsTestDataset = Dataset.load_from_file(self.ratingsTestPath, reader=reader)

        return ratingsTestDataset.construct_testset(ratingsTestDataset.raw_ratings)

    def loadMovies(self):
        """Read the movie metadata file into a dict keyed by movie ID.

        Uses columns 0 (ID), 1 (title), 2 (synopsis), 10 (genres) and
        16 (actors) of the tab-separated file. Pipe-separated genre and actor
        lists are re-joined with ``", "``; the ``\\N`` marker is kept verbatim.

        Returns:
            ``defaultdict(dict)`` mapping movie ID (string) to a dict with the
            keys ``Title``, ``Genres``, ``Actors`` and ``Synopsis``.
        """
        movieID_to_info = defaultdict(dict)

        with open(self.moviesPath, newline='') as csvfile:
            for row in csv.reader(csvfile, delimiter='\t'):
                movieID_to_info[row[0]] = {"Title": row[1],
                                           "Genres": self._pipeListToText(row[10]),
                                           "Actors": self._pipeListToText(row[16]),
                                           "Synopsis": row[2]}

        return movieID_to_info

    def getGenres(self):
        """Build a 0/1 genre vector for every movie in the metadata file.

        Genre names are taken from column 10 (pipe separated) and numbered in
        order of first appearance. A movie whose genre field is the ``\\N``
        marker has no genres and gets an all-zero vector.

        Returns:
            ``defaultdict(list)`` mapping movie ID (string) to a list of 0/1
            ints whose length is the number of distinct genres in the file.
        """
        genreIDsByMovie = {}
        genreIDs = {}

        with open(self.moviesPath, newline='') as csvfile:
            for row in csv.reader(csvfile, delimiter='\t'):
                rawGenres = row[10]
                # The missing marker must not be iterated character by
                # character, otherwise "\" and "N" would become genres.
                genreNames = [] if rawGenres == self._MISSING else rawGenres.split("|")

                # setdefault assigns the next free ID on first sight.
                genreIDsByMovie[row[0]] = [genreIDs.setdefault(genre, len(genreIDs))
                                           for genre in genreNames]

        # Convert integer-encoded genre lists to bitfields that we can treat as vectors
        genres = defaultdict(list)
        genreCount = len(genreIDs)
        for movieID, genreIDList in genreIDsByMovie.items():
            bitfield = [0] * genreCount
            for genreID in genreIDList:
                bitfield[genreID] = 1
            genres[movieID] = bitfield

        return genres

    def getYears(self):
        """Extract the release year from the end of each movie title.

        A year is recognised as four digits in parentheses at the end of the
        title column (column 1), optionally followed by whitespace.

        Returns:
            ``defaultdict(int)`` mapping movie ID (string) to the year as int;
            movies without a recognisable year are not inserted.
        """
        # The group is optional, so search() always matches; group(1) is
        # None when the title carries no trailing "(YYYY)".
        p = re.compile(r"(?:\((\d{4})\))?\s*$")
        years = defaultdict(int)
        with open(self.moviesPath, newline='') as csvfile:
            for row in csv.reader(csvfile, delimiter='\t'):
                title = row[1]
                year = p.search(title).group(1)
                if year:
                    years[row[0]] = int(year)
        return years


if __name__ == "__main__":
    # Manual smoke test: build the loader and print one movie's release year.
    dataset = YahooDataset()

    # Other things that can be inspected the same way:
    #   dataset.loadMovies()                          -> all movie metadata
    #   dataset.getGenres()['1808406117']             -> genre vector of a movie
    #   dataset.loadYahooPandasTestDataFrame().head(10)

    # The result is a defaultdict(int), so an unknown ID prints 0.
    releaseYears = dataset.getYears()
    print(releaseYears['1808406117'])

